{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.1 模型构造"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "class MLP(nn.Module):\n",
    "    # 声明带有模型参数的层，这里声明了两个全连接层\n",
    "    def __init__(self, **kwargs):\n",
    "        # 调用MLP父类Module的构造函数来进行初始化\n",
    "        # 这样在构造实例时还可以指定其他函数参数\n",
    "        # 如3.2节将介绍的模型参数params\n",
    "        super(MLP, self).__init__(**kwargs)\n",
    "        self.hidden = nn.Linear(784, 256)\n",
    "        self.act = nn.ReLU()\n",
    "        self.output = nn.Linear(256, 10)\n",
    "    # 定义模型的前向计算，即如何根据输入x计算返回所需要的模型输出\n",
    "    def forward(self, x):\n",
    "        a = self.act(self.hidden(x))\n",
    "        return self.output(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP(\n",
      "  (hidden): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (act): ReLU()\n",
      "  (output): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1166,  0.3709,  0.1625, -0.0634,  0.0378, -0.0738, -0.3726,  0.2768,\n",
       "         -0.0705, -0.1382],\n",
       "        [-0.0873,  0.2201, -0.0406, -0.0802, -0.0960,  0.2466, -0.3343,  0.0729,\n",
       "          0.0937, -0.4173]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.randn(2, 784)\n",
    "net = MLP()\n",
    "print(net)\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MySequential(nn.Module):\n",
    "    from collections import OrderedDict\n",
    "    def __init__(self, *args):\n",
    "        super(MySequential, self).__init__()\n",
    "        # 如果传入的是一个OrderedDict\n",
    "        if len(args)==1 and isinstance(args[0], OrderedDict):\n",
    "            for key, module in args[0].items():\n",
    "                # add_module方法会将module添加进self._modules(一个OrderedDict)\n",
    "                self.add_module(key, module)\n",
    "        # 传入的是一些Module\n",
    "        else:\n",
    "            for idx, module in enumerate(args):\n",
    "                self.add_module(str(idx), module)\n",
    "    def forward(self, input):\n",
    "        # self._modules返回一个OrderedDict，保证会按照成员添加时的顺序遍历成员\n",
    "        for module in self._modules.values():\n",
    "            input = module(input)\n",
    "        return input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MySequential(\n",
      "  (0): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1801, -0.2176,  0.0798, -0.1621, -0.0156, -0.0083,  0.2796,  0.2647,\n",
       "          0.0342,  0.3056],\n",
       "        [ 0.1345, -0.5100, -0.0496,  0.3354,  0.0183, -0.5262, -0.3022, -0.0921,\n",
       "          0.2138,  0.2389]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = MySequential(\n",
    "    nn.Linear(784, 256), \n",
    "    nn.ReLU(), \n",
    "    nn.Linear(256, 10)\n",
    ")\n",
    "print(net)\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FancyMLP(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(FancyMLP, self).__init__(**kwargs)\n",
    "        # 不可训练参数（常数参数）\n",
    "        self.rand_weight = torch.rand((20, 20), requires_grad=False)\n",
    "        self.linear = nn.Linear(20, 20)\n",
    "    def forward(self, x):\n",
    "        x = self.linear(x)\n",
    "        # 使用创建的常数参数，以及nn.functional中的relu函数和mm函数\n",
    "        x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)\n",
    "        # 复用全连接层，等价于两个全连接层共享参数\n",
    "        x = self.linear(x)\n",
    "        # 控制流，这里我们需要调用item函数来返回标量进行比较\n",
    "        while x.norm().item() > 1:\n",
    "            x /= 2\n",
    "        if x.norm().item() < 0.8:\n",
    "            x *= 10\n",
    "        return x.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FancyMLP(\n",
      "  (linear): Linear(in_features=20, out_features=20, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(-2.4393, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.rand(2, 20)\n",
    "net = FancyMLP()\n",
    "print(net)\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): NestMLP(\n",
      "    (net): Sequential(\n",
      "      (0): Linear(in_features=40, out_features=30, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "  )\n",
      "  (1): Linear(in_features=30, out_features=20, bias=True)\n",
      "  (2): FancyMLP(\n",
      "    (linear): Linear(in_features=20, out_features=20, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(-4.2659, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class NestMLP(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(NestMLP, self).__init__(**kwargs)\n",
    "        self.net = nn.Sequential(nn.Linear(40, 30), nn.ReLU())\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())\n",
    "X = torch.rand(2, 40)\n",
    "print(net)\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.2 模型参数的访问、初始化和共享"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=20, out_features=256, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(nn.Linear(20, 256), \n",
    "                   nn.ReLU(), \n",
    "                   nn.Linear(256, 10))\n",
    "print(net)\n",
    "X = torch.rand(2, 20)\n",
    "Y = net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weight torch.Size([256, 20]) <class 'torch.nn.parameter.Parameter'>\n",
      "bias torch.Size([256]) <class 'torch.nn.parameter.Parameter'>\n"
     ]
    }
   ],
   "source": [
    "for name, param in net[0].named_parameters():\n",
    "    print(name, param.size(), type(param))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Tensor, torch.Size([256, 20]))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weight_0 = list(net[0].parameters())[0]\n",
    "type(weight_0.data[0][0]), weight_0.data.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "param = dict(net[0].named_parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.nn.parameter.Parameter, torch.Size([256, 20]))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(param['weight']), param['weight'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.nn.parameter.Parameter, torch.Size([256]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(param['bias']), param['bias'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.1043, -0.1395,  0.0299,  ...,  0.1190,  0.1244,  0.1795],\n",
       "        [-0.0209, -0.2027,  0.1324,  ...,  0.0212, -0.0165, -0.0586],\n",
       "        [ 0.1829, -0.0290,  0.0041,  ..., -0.1630,  0.1055, -0.1026],\n",
       "        ...,\n",
       "        [-0.1275,  0.2187, -0.0203,  ...,  0.0015, -0.1478, -0.0467],\n",
       "        [-0.0304,  0.1645, -0.1152,  ..., -0.1894,  0.1787, -0.1475],\n",
       "        [ 0.1251, -0.0461,  0.1948,  ..., -0.1353,  0.0808,  0.1132]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weight_0.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    }
   ],
   "source": [
    "print(weight_0.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.0488,  0.0115,  0.0571, -0.0447,  0.0059,  0.0128,  0.0396,  0.0289,\n",
       "         0.0044,  0.0107])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bias_1 = list(net[2].parameters())[1]\n",
    "bias_1.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=20, out_features=256, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.weight tensor([[ 0.0061,  0.0113, -0.0155,  ...,  0.0133, -0.0025, -0.0187],\n",
      "        [-0.0063,  0.0111,  0.0108,  ..., -0.0122, -0.0177,  0.0070],\n",
      "        [-0.0170, -0.0057,  0.0131,  ..., -0.0080,  0.0020, -0.0108],\n",
      "        ...,\n",
      "        [-0.0081,  0.0010, -0.0115,  ..., -0.0056, -0.0139,  0.0035],\n",
      "        [ 0.0054,  0.0214, -0.0071,  ...,  0.0031,  0.0036,  0.0052],\n",
      "        [ 0.0078, -0.0092,  0.0023,  ...,  0.0007, -0.0032,  0.0060]])\n",
      "2.weight tensor([[ 4.4451e-03,  2.8794e-03,  1.5734e-02,  ..., -3.9530e-03,\n",
      "          5.8971e-03, -7.4768e-03],\n",
      "        [-2.2786e-03,  9.6592e-03, -3.5578e-03,  ..., -1.5789e-02,\n",
      "         -4.1393e-03, -2.8530e-05],\n",
      "        [ 3.6222e-03, -3.7768e-03, -2.5516e-03,  ...,  9.3914e-03,\n",
      "         -1.3788e-03, -1.5163e-02],\n",
      "        ...,\n",
      "        [ 6.2915e-03, -6.8922e-04,  4.3681e-03,  ..., -1.0601e-02,\n",
      "          1.3362e-02,  8.1113e-05],\n",
      "        [-6.7424e-03, -2.1240e-02, -1.2532e-02,  ...,  5.3504e-03,\n",
      "          1.3282e-04, -1.0524e-02],\n",
      "        [-1.0599e-02,  4.7868e-05, -6.8219e-03,  ...,  2.4288e-03,\n",
      "          9.2932e-04,  5.2038e-03]])\n"
     ]
    }
   ],
   "source": [
    "from torch.nn import init\n",
    "for name, param in net.named_parameters():\n",
    "    if 'weight' in name:\n",
    "        init.normal_(param, mean=0, std=0.01)\n",
    "        print(name, param.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n",
      "2.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n"
     ]
    }
   ],
   "source": [
    "for name, param in net.named_parameters():\n",
    "    if 'bias' in name:\n",
    "        init.constant_(param, val=0)\n",
    "        print(name, param.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0.0490,  0.0922,  0.0252,  0.0345, -0.0375,  0.0889,  0.0686,  0.0595,\n",
      "         0.0388, -0.0629, -0.0648, -0.0676,  0.0073, -0.0999,  0.0053, -0.0298,\n",
      "         0.1312, -0.0437, -0.0759,  0.1100])\n",
      "tensor([[42., 42., 42.,  ..., 42., 42., 42.],\n",
      "        [42., 42., 42.,  ..., 42., 42., 42.],\n",
      "        [42., 42., 42.,  ..., 42., 42., 42.],\n",
      "        ...,\n",
      "        [42., 42., 42.,  ..., 42., 42., 42.],\n",
      "        [42., 42., 42.,  ..., 42., 42., 42.],\n",
      "        [42., 42., 42.,  ..., 42., 42., 42.]])\n"
     ]
    }
   ],
   "source": [
    "def xavier(m):\n",
    "    if type(m)==nn.Linear:\n",
    "        torch.nn.init.xavier_uniform_(m.weight)\n",
    "def init_42(m):\n",
    "    if type(m)==nn.Linear:\n",
    "        torch.nn.init.constant_(m.weight, 42)\n",
    "net[0].apply(xavier)\n",
    "net[2].apply(init_42)\n",
    "print(net[0].weight.data[0])\n",
    "print(net[2].weight.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.weight tensor([[ 0.0000, -0.0000, -9.0268,  ..., -0.0000, -6.2936, -6.7501],\n",
      "        [ 0.0000, -7.2599, -6.1810,  ..., -0.0000,  0.0000,  0.0000],\n",
      "        [ 6.3452, -6.5682, -9.5816,  ..., -7.2323,  0.0000,  7.0041],\n",
      "        ...,\n",
      "        [-7.9436,  5.7349, -0.0000,  ...,  0.0000,  0.0000, -0.0000],\n",
      "        [ 0.0000, -0.0000, -0.0000,  ..., -0.0000,  0.0000,  0.0000],\n",
      "        [ 0.0000,  7.5755,  9.9136,  ...,  7.4412,  0.0000,  0.0000]])\n",
      "2.weight tensor([[ 0.0000, -0.0000,  7.1459,  ..., -7.5689,  9.1732,  0.0000],\n",
      "        [ 0.0000,  0.0000, -5.6028,  ..., -5.5531, -9.9678, -9.8330],\n",
      "        [ 0.0000,  0.0000,  0.0000,  ...,  7.7888,  0.0000, -7.8483],\n",
      "        ...,\n",
      "        [ 8.7079, -0.0000,  0.0000,  ..., -0.0000, -9.1407,  0.0000],\n",
      "        [-7.5157,  0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
      "        [ 7.9777,  0.0000,  0.0000,  ...,  9.2163, -0.0000, -0.0000]])\n"
     ]
    }
   ],
   "source": [
    "def init_weight_(tensor):\n",
    "    with torch.no_grad():\n",
    "        tensor.uniform_(-10, 10)\n",
    "        tensor *= (tensor.abs() >= 5).float()\n",
    "for name, param in net.named_parameters():\n",
    "    if 'weight' in name:\n",
    "        init_weight_(param)\n",
    "        print(name, param.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.bias tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
      "        1., 1., 1., 1.])\n",
      "2.bias tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])\n"
     ]
    }
   ],
   "source": [
    "for name, param in net.named_parameters():\n",
    "    if 'bias' in name:\n",
    "        param.data += 1\n",
    "        print(name, param.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=1, out_features=1, bias=False)\n",
      "  (1): Linear(in_features=1, out_features=1, bias=False)\n",
      ")\n",
      "0.weight tensor([[3.]])\n"
     ]
    }
   ],
   "source": [
    "linear = nn.Linear(1, 1, bias=False)\n",
    "net = nn.Sequential(linear, linear)\n",
    "print(net)\n",
    "for name, param in net.named_parameters():\n",
    "    init.constant_(param, val=3)\n",
    "    print(name, param.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "print(id(net[0]) == id(net[1]))\n",
    "print(id(net[0].weight) == id(net[1].weight))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(9., grad_fn=<SumBackward0>)\n",
      "tensor([[6.]])\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(1, 1)\n",
    "y = net(x).sum()\n",
    "print(y)\n",
    "y.backward()\n",
    "# 单次梯度是3，两次所以就是6\n",
    "print(net[0].weight.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CenteredLayer(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(CenteredLayer, self).__init__(**kwargs)\n",
    "    def forward(self, x):\n",
    "        return x - x.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2., -1.,  0.,  1.,  2.])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer = CenteredLayer()\n",
    "layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-9.313225746154785e-10"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = net(torch.rand(4, 8))\n",
    "y.mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MyDense(\n",
      "  (params): ParameterList(\n",
      "      (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (3): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class MyDense(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MyDense, self).__init__()\n",
    "        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])\n",
    "        self.params.append(nn.Parameter(torch.randn(4, 1)))\n",
    "    def forward(self, x):\n",
    "        for i in range(len(self.params)):\n",
    "            x = torch.mm(x, self.params[i])\n",
    "        return x\n",
    "net = MyDense()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MyDictDense(\n",
      "  (params): ParameterDict(\n",
      "      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class MyDictDense(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MyDictDense, self).__init__()\n",
    "        self.params = nn.ParameterDict({\n",
    "            'linear1': nn.Parameter(torch.randn(4, 4)), \n",
    "            'linear2': nn.Parameter(torch.randn(4, 1))\n",
    "        })\n",
    "        # 新增\n",
    "        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})\n",
    "    def forward(self, x, choice='linear1'):\n",
    "        return torch.mm(x, self.params[choice])\n",
    "net = MyDictDense()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-2.3688, -0.6960, -0.5341,  2.3085]], grad_fn=<MmBackward>)\n",
      "tensor([[-1.9920]], grad_fn=<MmBackward>)\n",
      "tensor([[0.4029, 0.7226]], grad_fn=<MmBackward>)\n"
     ]
    }
   ],
   "source": [
    "x = torch.ones(1, 4)\n",
    "print(net(x, 'linear1'))\n",
    "print(net(x, 'linear2'))\n",
    "print(net(x, 'linear3'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): MyDictDense(\n",
      "    (params): ParameterDict(\n",
      "        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
      "    )\n",
      "  )\n",
      "  (1): MyDense(\n",
      "    (params): ParameterList(\n",
      "        (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
      "        (3): Parameter containing: [torch.FloatTensor of size 4x1]\n",
      "    )\n",
      "  )\n",
      ")\n",
      "tensor([[3.5523]], grad_fn=<MmBackward>)\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "    MyDictDense(), \n",
    "    MyDense()\n",
    ")\n",
    "print(net)\n",
    "print(net(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.ones(3)\n",
    "torch.save(x, 'x.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 1., 1.])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x2 = torch.load('x.pt')\n",
    "x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = torch.zeros(4)\n",
    "torch.save([x, y], 'xy.pt')\n",
    "xy_list = torch.load('xy.pt')\n",
    "xy_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.save({'x': x, 'y': y}, 'xy_dict.pt')\n",
    "xy = torch.load('xy_dict.pt')\n",
    "xy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data2/wzy/SoftWare/anaconda3/envs/PyTorch/lib/python3.7/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type MLP. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    }
   ],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MLP, self).__init__()\n",
    "        self.hidden = nn.Linear(3, 2)\n",
    "        self.act = nn.ReLU()\n",
    "        self.output = nn.Linear(2, 1)\n",
    "    def forward(self, x):\n",
    "        a = self.act(self.hidden(x))\n",
    "        return self.output(a)\n",
    "net = MLP()\n",
    "torch.save(net, 'model.bin')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "net2 = torch.load('model.bin')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1], dtype=torch.uint8)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y2 = net2(x)\n",
    "Y = net(x)\n",
    "Y2 == Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(device(type='cpu'),\n",
       " <torch.cuda.device at 0x7f2cba460e10>,\n",
       " <torch.cuda.device at 0x7f2d792a1710>)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.device('cpu'), torch.cuda.device('cuda'), torch.cuda.device('cuda:1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([1, 2, 3])\n",
    "x.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1, 2, 3], device='cuda:0')"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.tensor([1, 2, 3], device=torch.device('cuda'))\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0251, 0.4416, 0.1046],\n",
       "        [0.5686, 0.7297, 0.8858]], device='cuda:1')"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "B = torch.rand(2, 3, device=torch.device('cuda:1'))\n",
    "B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1, 2, 3])\n",
      "tensor([1, 2, 3], device='cuda:1')\n"
     ]
    }
   ],
   "source": [
    "Z = x.cuda(1)\n",
    "print(x)\n",
    "print(Z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "expected backend CUDA and dtype Long but got backend CPU and dtype Long",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-49-62ad4af72e42>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mZ\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m: expected backend CUDA and dtype Long but got backend CPU and dtype Long"
     ]
    }
   ],
   "source": [
    "Z + x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 20.0855, 109.1963, 445.2395], device='cuda:1')"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.exp((Z+2).float())*x.float().cuda(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = nn.Linear(3, 1)\n",
    "net.cuda()\n",
    "net.weight.data.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.4004]], device='cuda:0', grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net(Z.view(1, 3).float().cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2655, -0.1501,  0.5458]], device='cuda:0')"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.weight.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:PyTorch] *",
   "language": "python",
   "name": "conda-env-PyTorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
